
Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8 #35053

Merged
mgoin merged 2 commits into vllm-project:main from de-inf:mxfp8_flashinfer
Feb 24, 2026

Conversation

Contributor

@danisereb danisereb commented Feb 22, 2026

Purpose

Follow-up to PR #33786.

The FlashInfer version used by vLLM was recently updated (to v0.6.4).

A new CUTLASS-based MXFP8 GEMM, mm_mxfp8, is now available:
flashinfer-ai/flashinfer#2464

This PR integrates this GEMM into vLLM (for ModelOpt MXFP8).
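
For background, MXFP8 stores FP8 (E4M3) values together with one shared power-of-two (E8M0) scale per 32-element block along the reduction dimension, and the GEMM consumes the quantized values plus those block scales. The sketch below illustrates the dynamic input-quantization step in plain PyTorch; it is only an illustration of the format, not this PR's code or FlashInfer's API.

```python
import torch

MXFP8_BLOCK_SIZE = 32   # MX formats share one E8M0 scale per 32 elements
FP8_E4M3_MAX = 448.0    # largest magnitude representable in float8_e4m3fn

def quantize_mxfp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize an [M, K] tensor to FP8 E4M3 values plus one power-of-two
    scale per 32 consecutive elements of K (illustrative only)."""
    M, K = x.shape
    assert K % MXFP8_BLOCK_SIZE == 0
    blocks = x.float().view(M, K // MXFP8_BLOCK_SIZE, MXFP8_BLOCK_SIZE)
    amax = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    # Round each block scale up to a power of two so it is E8M0-representable
    # and the scaled values stay within the E4M3 range.
    scale = torch.exp2(torch.ceil(torch.log2(amax / FP8_E4M3_MAX)))
    q = (blocks / scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q.view(M, K).to(torch.float8_e4m3fn), scale.view(M, K // MXFP8_BLOCK_SIZE)
```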

Test Plan

Use the following model for testing (used in other related PRs):
https://huggingface.co/nvidia/OpenMath2-Llama3.1-8B

Compare performance (tok/sec) and lm_eval results between the original BF16 model and MXFP8 model.

Test Result

Eval / accuracy

Command:

lm_eval \
  --model vllm \
  --model_args pretrained=$MODEL_PATH,max_model_len=4096,enforce_eager=True,attention_backend=TRITON_ATTN \
  --tasks gsm8k \
  --batch_size auto --limit 400

Results (GPU B200):

# OpenMath2-Llama3.1-8B

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8425|±  |0.0182|
|     |       |strict-match    |     5|exact_match|↑  |0.2250|±  |0.0209|

# OpenMath2-Llama3.1-8B-MXFP8

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8275|±  |0.0189|
|     |       |strict-match    |     5|exact_match|↑  |0.2225|±  |0.0208|

Performance benchmark

Command:

vllm bench throughput --model $MODEL_PATH \
--tensor-parallel-size 1 \
--async-scheduling \
--backend vllm \
--dataset-name random \
--random-prefix-len 0 \
--random-input-len 1024 \
--random-output-len 1024 \
--max-num-seqs 128 \
--num-prompts 1024

Results (GPU B200):

# OpenMath2-Llama3.1-8B (BF16)

Throughput: 13.22 requests/s, 27083.73 total tokens/s, 13541.87 output tokens/s
Total num prompt tokens:  1048576
Total num output tokens:  1048576

# OpenMath2-Llama3.1-8B-MXFP8

Throughput: 19.05 requests/s, 39006.42 total tokens/s, 19503.21 output tokens/s
Total num prompt tokens:  1048576
Total num output tokens:  1048576


Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

The pull request integrates the FlashInfer mm_mxfp8 GEMM into vLLM for ModelOpt MXFP8 quantization. The implementation includes the necessary swizzling logic for weight scales and dynamic quantization for inputs. However, critical issues were identified: incorrect keyword arguments in the FlashInfer wrapper will cause runtime errors, and hard assertions on minimum layer dimensions will crash models that contain small layers. A fallback to the emulation backend should be implemented for unsupported shapes.

Comment on lines +179 to +189
assert min_dim <= K, (
    f"mm_mxfp8 requires K >= {min_dim}, got K={K}. "
    f"in_features is too small for mm_mxfp8."
)
assert K % MXFP8_BLOCK_SIZE == 0, (
    f"mm_mxfp8 requires K to be divisible by {MXFP8_BLOCK_SIZE}, got K={K}."
)
assert min_dim <= N, (
    f"mm_mxfp8 requires N >= {min_dim}, got N={N}. "
    f"out_features is too small for mm_mxfp8."
)
Contributor


high

These hard assertions on min_dim=128 for K and N will cause vLLM to crash if the model contains any linear layers with dimensions smaller than 128 (e.g., router gates or small projection layers). Instead of crashing, the implementation should detect unsupported shapes and fall back to the EMULATION backend for those specific layers. Note that this requires ensuring the weight scales are processed correctly (not swizzled) for the fallback backend during the weight loading phase.

Contributor Author


I prefer not to fall back to emulation and instead raise an error.

Emulation has lower performance than CUTLASS, and users may not notice that the fallback was triggered.

ModelOpt MXFP8 support is new; changes to the backend-selection logic can be added later as needed.

Collaborator


Assertions could be raised before kernel execution, for example during post-weight processing, if we recognize that the model has shapes incompatible with the chosen kernel backend. If the kernel's apply is the only point of failure / check, it errors out much later, only when the kernel is actually invoked.
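
A minimal sketch of that idea (hypothetical helper name and constants, not the PR's actual code): run the shape check once while weights are processed after loading, so an incompatible model fails at load time rather than on the first forward pass.

```python
import torch

MXFP8_BLOCK_SIZE = 32  # assumed block size along K
MIN_DIM = 128          # assumed minimum N/K supported by the CUTLASS kernel

def check_mxfp8_shape(weight: torch.Tensor) -> None:
    """Raise during post-weight processing if a layer cannot use mm_mxfp8."""
    N, K = weight.shape  # out_features, in_features
    if N < MIN_DIM or K < MIN_DIM or K % MXFP8_BLOCK_SIZE != 0:
        raise ValueError(
            f"mm_mxfp8 does not support weight shape N={N}, K={K}: "
            f"requires N, K >= {MIN_DIM} and K % {MXFP8_BLOCK_SIZE} == 0."
        )
```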

Contributor Author

@danisereb danisereb Feb 23, 2026


If @mgoin merges his PR #34664 (Marlin MXFP8 GEMM) first, I will align my PR with his.
In that case I'll add a select_mxfp8_linear_backend function that selects the CUTLASS / Marlin / emulation backend (falling back to Marlin if CUTLASS is not supported).

Maybe an assert should be used only if the user sets an env var to force the CUTLASS MXFP8 GEMM (or we could follow logic similar to the existing NVFP4 / FP8 select_*_backend functions).
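
A rough sketch of what that selection could look like (the function name comes from the comment above; the enum, thresholds, and capability check are assumptions, not actual vLLM code):

```python
from enum import Enum

import torch

class Mxfp8Backend(Enum):
    CUTLASS = "cutlass"      # FlashInfer mm_mxfp8 (Blackwell / SM100)
    MARLIN = "marlin"        # Marlin MXFP8 GEMM (PR #34664)
    EMULATION = "emulation"  # dequantize and run a high-precision matmul

def select_mxfp8_linear_backend(N: int, K: int, min_dim: int = 128) -> Mxfp8Backend:
    """Pick the fastest backend whose shape and hardware constraints are met."""
    major, _ = torch.cuda.get_device_capability()
    if major >= 10 and N >= min_dim and K >= min_dim and K % 32 == 0:
        return Mxfp8Backend.CUTLASS
    # Fall back to Marlin when the CUTLASS constraints are not met;
    # emulation remains the last resort on unsupported hardware.
    return Mxfp8Backend.MARLIN
```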


dosubot bot commented Feb 22, 2026

Related Documentation

Checked 0 published document(s) in 1 knowledge base(s). No updates required.


Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
@danisereb danisereb changed the title from "Integrate flashinfer mm_mxfp8 (MXFP8 GEMM)" to "Integrate flashinfer mm_mxfp8 (MXFP8 GEMM) in ModelOpt MXFP8" on Feb 22, 2026
@danisereb danisereb changed the title from "Integrate flashinfer mm_mxfp8 (MXFP8 GEMM) in ModelOpt MXFP8" to "Integrate flashinfer mm_mxfp8 in ModelOpt MXFP8" on Feb 22, 2026
@mgoin mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) and quantization labels on Feb 22, 2026
@github-project-automation github-project-automation bot moved this to Ready in NVIDIA Feb 24, 2026
@mgoin mgoin merged commit 9609b1f into vllm-project:main Feb 24, 2026
67 checks passed
@github-project-automation github-project-automation bot moved this from Ready to Done in NVIDIA Feb 24, 2026
tom-zju pushed a commit to tom-zju/vllm that referenced this pull request Feb 26, 2026
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
llsj14 pushed a commit to llsj14/vllm that referenced this pull request Mar 1, 2026
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
tunglinwood pushed a commit to tunglinwood/vllm that referenced this pull request Mar 4, 2026
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
askliar pushed a commit to askliar/vllm that referenced this pull request Mar 9, 2026
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>
Signed-off-by: Andrii Skliar <askliar@nvidia.com>
Copilot AI pushed a commit to machov/vllm that referenced this pull request Mar 10, 2026
Signed-off-by: Daniel Serebrenik <daserebrenik@nvidia.com>

Labels

nvidia, quantization, ready (ONLY add when PR is ready to merge/full CI is needed)

Projects

Status: Done

3 participants